{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7ff16",
   "metadata": {},
   "source": [
    "# T5. trainer 和 evaluator 的深入介绍\n",
    "\n",
    "&emsp; 1 &ensp; fastNLP 中 driver 的补充介绍\n",
    " \n",
    "&emsp; &emsp; 1.1 &ensp; trainer 和 driver 的构想 \n",
    "\n",
    "&emsp; &emsp; 1.2 &ensp; device 与 多卡训练\n",
    "\n",
    "&emsp; 2 &ensp; fastNLP 中的更多 metric 类型\n",
    "\n",
    "&emsp; &emsp; 2.1 &ensp; 预定义的 metric 类型\n",
    "\n",
    "&emsp; &emsp; 2.2 &ensp; 自定义的 metric 类型\n",
    "\n",
    "&emsp; 3 &ensp; fastNLP 中 trainer 的补充介绍\n",
    "\n",
    "&emsp; &emsp; 3.1 &ensp; trainer 的内部结构"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08752c5a",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 1. fastNLP 中 driver 的补充介绍\n",
    "\n",
    "### 1.1  trainer 和 driver 的构想\n",
    "\n",
    "在`fastNLP 0.8`中，模型训练最关键的模块便是**训练模块`trainer`、评测模块`evaluator`、驱动模块`driver`**，\n",
    "\n",
    "&emsp; 在`tutorial 0`中，已经简单介绍过上述三个模块：**`driver`用来控制训练评测中的`model`的最终运行**\n",
    "\n",
    "&emsp; &emsp; **`evaluator`封装评测的`metric`**，**`trainer`封装训练的`optimizer`**，**也可以包括`evaluator`**\n",
    "\n",
    "之所以做出上述的划分，其根本目的在于要**达成对于多个`python`学习框架**，**例如`pytorch`、`paddle`、`jittor`的兼容**\n",
    "\n",
    "&emsp; 对于训练环节，其伪代码如下方左边紫色一栏所示，由于**不同框架对模型、损失、张量的定义各有不同**，所以将训练环节\n",
    "\n",
    "&emsp; &emsp; 划分为**框架无关的循环控制、批量分发部分**，**由`trainer`模块负责**实现，对应的伪代码如下方中间蓝色一栏所示\n",
    "\n",
    "&emsp; &emsp; 以及**随框架不同的模型调用、数值优化部分**，**由`driver`模块负责**实现，对应的伪代码如下方右边红色一栏所示\n",
    "\n",
    "| <div align=\"center\">训练过程</div> | <div align=\"center\">框架无关 对应`trainer`</div> | <div align=\"center\">框架相关 对应`driver`</div> |\n",
    "|:--|:--|:--|\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">try:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">try:</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">for epoch in 1:n_eoochs:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">for epoch in 1:n_eoochs:</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">for step in 1:total_steps:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">for step in 1:total_steps:</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">batch = fetch_batch()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:60px;\">batch = fetch_batch()</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">loss = model.forward(batch)&emsp;</div> |  | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">loss = model.forward(batch)&emsp;</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">loss.backward()</div> |  | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">loss.backward()</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.clear_grad()</div> |  | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.clear_grad()</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.update()</div> |  | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.update()</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">if need_save:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">if need_save:</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:60px;\">model.save()</div> |  | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:60px;\">model.save()</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">except:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">except:</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">process_exception()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">process_exception()</div> |  |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e55f07b",
   "metadata": {},
   "source": [
    "&emsp; 对于评测环节，其伪代码如下方左边紫色一栏所示，同样由于不同框架对模型、损失、张量的定义各有不同，所以将评测环节\n",
    "\n",
    "&emsp; &emsp; 划分为**框架无关的循环控制、分发汇总部分**，**由`evaluator`模块负责**实现，对应的伪代码如下方中间蓝色一栏所示\n",
    "\n",
    "&emsp; &emsp; 以及**随框架不同的模型调用、评测计算部分**，同样**由`driver`模块负责**实现，对应的伪代码如下方右边红色一栏所示\n",
    "\n",
    "| <div align=\"center\">评测过程</div> | <div align=\"center\">框架无关 对应`evaluator`</div> | <div align=\"center\">框架相关 对应`driver`</div> |\n",
    "|:--|:--|:--|\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">try:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">try:</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">model.set_eval()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">model.set_eval()</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">for step in 1:total_steps:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">for step in 1:total_steps:</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">batch = fetch_batch()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:40px;\">batch = fetch_batch()</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">outputs = model.evaluate(batch)&emsp;</div> |  | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:40px;\">outputs = model.evaluate(batch)&emsp;</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:40px;\">metric.compute(batch, outputs)</div> |  | <div style=\"font-family:Consolas;font-weight:bold;color:red;text-indent:40px;\">metric.compute(batch, outputs)</div> |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">results = metric.get_metric()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">results = metric.get_metric()</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;\">except:</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;\">except:</div> |  |\n",
    "| <div style=\"font-family:Consolas;font-weight:bold;color:purple;text-indent:20px;\">process_exception()</div> | <div style=\"font-family:Consolas;font-weight:bold;color:blue;text-indent:20px;\">process_exception()</div> |  |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94ba11c6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "由此，从程序员的角度，`fastNLP v0.8`**通过一个`driver`让基于`pytorch`、`paddle`、`jittor`框架的模型**\n",
    "\n",
    "&emsp; &emsp; **都能在相同的`trainer`和`evaluator`上运行**，这也**是`fastNLP v0.8`相比于之前版本的一大亮点**\n",
    "\n",
    "&emsp; 而从`driver`的角度，`fastNLP v0.8`通过定义一个`driver`基类，**将所有张量转化为`numpy.tensor`**\n",
    "\n",
    "&emsp; &emsp; 并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类，从而实现了\n",
    "\n",
    "&emsp; &emsp; 对`pytorch`、`paddle`、`jittor`的兼容，有关后两者的实践请参考接下来的`tutorial-6`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab1cea7d",
   "metadata": {},
   "source": [
    "### 1.2  device 与 多卡训练\n",
    "\n",
    "**`fastNLP v0.8`支持多卡训练**，实现方法则是**通过将`trainer`中的`device`设置为对应显卡的序号列表**\n",
    "\n",
    "&emsp; 由单卡切换成多卡，无论是数据、模型还是评测都会面临一定的调整，`fastNLP v0.8`保证：\n",
    "\n",
    "&emsp; &emsp; 数据拆分时，不同卡之间相互协调，所有数据都可以被训练，且不会使用到相同的数据\n",
    "\n",
    "&emsp; &emsp; 模型训练时，模型之间需要交换梯度；评测计算时，每张卡先各自计算，再汇总结果\n",
    "\n",
    "&emsp; 例如，在评测计算运行`get_metric`函数时，`fastNLP v0.8`将自动按照`self.right`和`self.total`\n",
    "\n",
    "&emsp; &emsp; 指定的**`aggregate_method`方法**，默认为`sum`，将每张卡上结果汇总起来，因此最终\n",
    "\n",
    "&emsp; &emsp; 在调用`get_metric`方法时，`Accuracy`类能够返回全部的统计结果，代码如下\n",
    "    \n",
    "```python\n",
    "trainer = Trainer(\n",
    "        model=model,                                # model 基于 pytorch 实现 \n",
    "        train_dataloader=train_dataloader,\n",
    "        optimizers=optimizer,\n",
    "        ...\n",
    "        driver='torch',                             # driver 使用 torch_driver \n",
    "        device=[0, 1],                              # gpu 选择 cuda:0 + cuda:1\n",
    "        ...\n",
    "        evaluate_dataloaders=evaluate_dataloader,\n",
    "        metrics={'acc': Accuracy()},\n",
    "        ...\n",
    "    )\n",
    "\n",
    "class Accuracy(Metric):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.register_element(name='total', value=0, aggregate_method='sum')\n",
    "        self.register_element(name='right', value=0, aggregate_method='sum')\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2e0a210",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "注：`fastNLP v0.8`中要求`jupyter`不能多卡，仅能单卡，故在所有`tutorial`中均不作相关演示"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d19220c",
   "metadata": {},
   "source": [
    "## 2. fastNLP 中的更多 metric 类型\n",
    "\n",
    "### 2.1  预定义的 metric 类型\n",
    "\n",
    "在`fastNLP 0.8`中，除了前几篇`tutorial`中经常见到的**正确率`Accuracy`**，还有其他**预定义的评测标准`metric`**\n",
    "\n",
    "&emsp; 包括**所有`metric`的基类`Metric`**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n",
    "\n",
    "&emsp; &emsp; **适用于分类语境下的`F1`值`ClassifyFPreRecMetric`**（其中也包括召回率`Pre`、精确率`Rec`\n",
    "\n",
    "&emsp; &emsp; **适用于抽取语境下的`F1`值`SpanFPreRecMetric`**；相关基本信息内容见下表，之后是详细分析\n",
    "\n",
    "| <div align=\"center\">代码名称</div> | <div align=\"center\">简要介绍</div> | <div align=\"center\">代码路径</div> |\n",
    "|:--|:--|:--|\n",
    "| `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n",
    "| `Accuracy` | 正确率，最为常用 | `/core/metrics/accuracy.py` |\n",
    "| `TransformersAccuracy` | 正确率，为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n",
    "| `ClassifyFPreRecMetric` | 召回率、精确率、F1值，适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n",
    "| `SpanFPreRecMetric` | 召回率、精确率、F1值，适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdc083a3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "&emsp; 如`tutorial-0`中所述，所有的`metric`都包含`get_metric`和`update`函数，其中\n",
    "\n",
    "&emsp; &emsp; **`update`函数更新单个`batch`的统计量**，**`get_metric`函数返回最终结果**，并打印显示\n",
    "\n",
    "\n",
    "### 2.1.1  Accuracy 与 TransformersAccuracy\n",
    "\n",
    "`Accuracy`，正确率，预测正确的数据`right_num`在总数据`total_num`，中的占比（公式就不用列了\n",
    "\n",
    "&emsp; `get_metric`函数打印格式为 **`{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}`**\n",
    "\n",
    "&emsp; 一般在初始化时不需要传参，`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n",
    "\n",
    "&emsp; **`update`函数的参数包括`pred`、`target`、`seq_len`**，**后者用来标记批次中每笔数据的长度**\n",
    "\n",
    "`TransformersAccuracy`，继承自`Accuracy`，只是为了兼容`Transformers`框架中相关模型\n",
    "\n",
    "&emsp; 在`update`函数中，将`Transformers`框架输出的`attention_mask`参数转化为`seq_len`参数\n",
    "\n",
    "\n",
    "### 2.1.2  ClassifyFPreRecMetric 与 SpanFPreRecMetric\n",
    "\n",
    "`ClassifyFPreRecMetric`，分类评价，`SpanFPreRecMetric`，抽取评价，后者在`tutorial-4`中已出现\n",
    "\n",
    "&emsp; 两者的相同之处在于：**第一**，**都包括召回率/查全率`Rec`**、**精确率/查准率`Pre`**、**`F1`值**这三个指标\n",
    "\n",
    "&emsp; &emsp; `get_metric`函数打印格式为 **`{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}`**\n",
    "\n",
    "&emsp; &emsp; 三者的计算公式如下，其中`beta`默认为`1`，即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n",
    "\n",
    "$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n",
    "\n",
    "$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n",
    "\n",
    "&emsp; **第二**，可以通过参数`only_gross`为`False`，要求返回所有类别的`Rec-Pre-F1`，同时`F1`值又根据参数`f_type`又分为\n",
    "\n",
    "&emsp; &emsp; **`micro F1`**（**直接统计所有类别的`Rec-Pre-F1`**）、**`macro F1`**（**统计各类别的`Rec-Pre-F1`再算术平均**）\n",
    "\n",
    "&emsp; **第三**，两者在初始化时还可以**传入基于`fastNLP.Vocabulary`的`tag_vocab`参数记录数据集中的标签序号**\n",
    "\n",
    "&emsp; &emsp; **与标签名称之间的映射**，通过字符串列表`ignore_labels`参数，指定若干标签不用于`Rec-Pre-F1`的计算\n",
    "\n",
    "两者的不同之处在于：`ClassifyFPreRecMetric`针对简单的分类问题，每个分类标签之间彼此独立，不构成标签对\n",
    "\n",
    "&emsp; &emsp; **`SpanFPreRecMetric`针对更复杂的抽取问题**，**规定标签`B-xx`和`I-xx`或`B-xx`和`E-xx`构成标签对**\n",
    "\n",
    "&emsp; 在计算`Rec-Pre-F1`时，`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了，但是\n",
    "\n",
    "&emsp; &emsp; 对于`SpanFPreRecMetric`，需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n",
    "\n",
    "&emsp; &emsp; 因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务，如果评测方法选择`ClassifyFPreRecMetric`\n",
    "\n",
    "&emsp; &emsp; &emsp; 或者`Accuracy`，会发现虽然评测结果显示很高，这是因为选择的评测方法要求太低\n",
    "\n",
    "&emsp; &emsp; 最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n",
    "\n",
    "```python\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP import ClassifyFPreRecMetric\n",
    "\n",
    "tag_vocab = Vocabulary(padding=None, unknown=None)            # 记录序号与标签之间的映射\n",
    "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n",
    "                        'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n",
    "                        'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n",
    "                        'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n",
    "                        'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ])  # CoNLL-2003 中的 pos_tags\n",
    "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n",
    "\n",
    "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab,          \n",
    "                                ignore_labels=ignore_labels,  # 表示评测/优化中不考虑上述标签的正误/损失\n",
    "                                only_gross=True,              # 默认为 True 表示输出所有类别的综合统计结果\n",
    "                                f_type='micro')               # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n",
    "metrics = {'F1': FPreRec}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a22f522",
   "metadata": {},
   "source": [
    "### 2.2  自定义的 metric 类型\n",
    "\n",
    "如上文所述，`Metric`作为所有`metric`的基类，`Accuracy`等都是其子类，同样地，对于**自定义的`metric`类型**\n",
    "\n",
    "&emsp; &emsp; 也**需要继承自`Metric`类**，同时**内部自定义好`__init__`、`update`和`get_metric`函数**\n",
    "\n",
    "&emsp; 在`__init__`函数中，根据需求定义评测时需要用到的变量，此处沿用`Accuracy`中的`total_num`和`right_num`\n",
    "\n",
    "&emsp; 在`update`函数中，根据需求定义评测变量的更新方式，需要注意的是如`tutorial-0`中所述，**`update`的参数名**\n",
    "\n",
    "&emsp; &emsp; **需要待评估模型在`evaluate_step`中的输出名称一致**，由此**和数据集中对应字段名称一致**，即**参数匹配**\n",
    "\n",
    "&emsp; &emsp; 在`fastNLP v0.8`中，`update`函数的默认输入参数：`pred`，对应预测值；`target`，对应真实值\n",
    "\n",
    "&emsp; &emsp; 此处仍然沿用，因为接下来会需要使用`fastNLP`函数的与定义模型，其输入参数格式即使如此\n",
    "\n",
    "&emsp; 在`get_metric`函数中，根据需求定义评测指标最终的计算，此处直接计算准确率，该函数必须返回一个字典\n",
    "\n",
    "&emsp; &emsp; 其中，字串`'prefix'`表示该`metric`的名称，会对应显示到`trainer`的`progress bar`中\n",
    "\n",
    "根据上述要求，这里简单定义了一个名为`MyMetric`的评测模块，用于分类问题的评测，以此展开一个实例展示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "08a872e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from fastNLP import Metric\n",
    "\n",
    "class MyMetric(Metric):\n",
    "\n",
    "    def __init__(self):\n",
    "        Metric.__init__(self)\n",
    "        self.total_num = 0\n",
    "        self.right_num = 0\n",
    "\n",
    "    def update(self, pred, target):\n",
    "        self.total_num += target.size(0)\n",
    "        self.right_num += target.eq(pred).sum().item()\n",
    "\n",
    "    def get_metric(self, reset=True):\n",
    "        acc = self.right_num / self.total_num\n",
    "        if reset:\n",
    "            self.total_num = 0\n",
    "            self.right_num = 0\n",
    "        return {'prefix': acc}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0155f447",
   "metadata": {},
   "source": [
    "&emsp; 数据使用方面，此处仍然使用`datasets`模块中的`load_dataset`函数，加载`SST-2`二分类数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5ad81ac7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ef923b90b19847f4916cccda5d33fc36",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "sst2data = load_dataset('glue', 'sst2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9d81760",
   "metadata": {},
   "source": [
    "&emsp; 在数据预处理中，需要注意的是，这里原本应该根据`metric`和`model`的输入参数格式，调整\n",
    "\n",
    "&emsp; &emsp; 数据集中表示预测目标的字段，调整为`target`，在后文中会揭晓为什么，以及如何补救"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cfb28b1b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/6000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from fastNLP import DataSet\n",
    "\n",
    "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
    "\n",
    "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n",
    "dataset.delete_field('sentence')\n",
    "dataset.delete_field('idx')\n",
    "\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.from_dataset(dataset, field_name='words')\n",
    "vocab.index_dataset(dataset, field_name='words')\n",
    "\n",
    "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
    "\n",
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af3f8c63",
   "metadata": {},
   "source": [
    "&emsp; 模型使用方面，此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型，实现`SST-2`二分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2fd210c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.models.torch import CNNText\n",
    "\n",
    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
    "\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e723b87",
   "metadata": {},
   "source": [
    "## 3. fastNLP 中 trainer 的补充介绍\n",
    "\n",
    "### 3.1  trainer 的内部结构\n",
    "\n",
    "在`tutorial-0`中，我们已经介绍了`trainer`的基本使用，从`tutorial-1`到`tutorial-4`，我们也已经展示了\n",
    "\n",
    "&emsp; 很多`trainer`的使用案例，这里通过表格，相对完整地介绍`trainer`模块的属性和初始化参数（标粗为必选参数\n",
    "\n",
    "| <div align=\"center\">名称</div> | <div align=\"center\">参数</div> | <div align=\"center\">属性</div> | <div align=\"center\">功能</div> | <div align=\"center\">内容</div> |\n",
    "|:--|:--:|:--:|:--|:--|\n",
    "| **`model`** | √ | √ | 指定`trainer`控制的模型 | 视框架而定，如`torch.nn.Module` |\n",
    "| `device` | √ |  | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n",
    "|  |  | √ | 记录`trainer`运行的卡位 | `Device`类型，在初始化阶段生成 |\n",
    "| **`driver`** | √ |  | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n",
    "|  |  | √ | 记录`trainer`驱动的框架 | `Driver`类型，在初始化阶段生成 |\n",
    "| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`，记录在`driver.n_epochs`中 |\n",
    "| **`optimizers`** | √ | √ | 指定`trainer`优化的方法 | 视框架而定，如`torch.optim.Adam` |\n",
    "| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型，如`{'acc': Metric()}` |\n",
    "| `evaluator` |  | √ | 内置的`trainer`评测模块 | `Evaluator`类型，在初始化阶段生成 |\n",
    "| `input_mapping` | √ | √ | 调整`dataloader`的参数不匹配 | 函数类型，输出字典匹配`forward`输入参数 |\n",
    "| `output_mapping` | √ | √ | 调整`forward`输出的参数不匹配 | 函数类型，输出字典匹配`xx_step`输入参数 |\n",
    "| **`train_dataloader`** | √ | √ | 指定`trainer`训练的数据 | `DataLoader`类型，生成视框架而定 |\n",
    "| `evaluate_dataloaders` | √ | √ | 指定`trainer`评测的数据 | `DataLoader`类型，生成视框架而定 |\n",
    "| `train_fn` | √ | √ | 指定`trainer`获取某个批次的损失值 | 函数类型，默认为`model.train_step` |\n",
    "| `evaluate_fn` | √ | √ | 指定`trainer`获取某个批次的评估量 | 函数类型，默认为`model.evaluate_step` |\n",
    "| `batch_step_fn` | √ | √ | 指定`trainer`训练时前向传输一个批次的方式 | 函数类型，默认为`TrainBatchLoop.batch_step_fn` |\n",
    "| `evaluate_batch_step_fn` | √ | √ | 指定`trainer`评测时前向传输一个批次的方式 | 函数类型，默认为`EvaluateBatchLoop.batch_step_fn` |\n",
    "| `accumulation_steps` | √ | √ | 指定`trainer`训练时反向传播的频率 | 默认为`1`，即每个批次都反向传播 |\n",
    "| `evaluate_every` | √ | √ | 指定`evaluator`评测时计算的频率 | 默认`-1`表示每个循环一次，相反`1`表示每个批次一次 |\n",
    "| `progress_bar` | √ | √ | 指定`trainer`训练和评测时的进度条样式 | 包括`'auto'`、`'tqdm'`、`'raw'`、`'rich'` |\n",
    "| `callbacks` | √ |  | 指定`trainer`训练时需要触发的函数 | `Callback`列表类型，详见`tutorial-7` |\n",
    "| `callback_manager` |  | √ | 记录与管理`callbacks`相关内容 | `CallbackManager`类型，详见`tutorial-7` |\n",
    "| `monitor` | √ | √ | 辅助部分的`callbacks`相关内容 | 字符串/函数类型，详见`tutorial-7` |\n",
    "| `marker` | √ | √ | 标记`trainer`实例，辅助`callbacks`相关内容 | 字符串型，详见`tutorial-7` |\n",
    "| `trainer_state` |  | √ | 记录`trainer`状态，辅助`callbacks`相关内容 | `TrainerState`类型，详见`tutorial-7` |\n",
    "| `state` |  | √ | 记录`trainer`状态，辅助`callbacks`相关内容 | `State`类型，详见`tutorial-7` |\n",
    "| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型，默认`False` |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e13ee08",
   "metadata": {},
   "source": [
    "其中，**`input_mapping`和`output_mapping`** 定义形式如下：输入字典形式的数据，根据参数匹配要求\n",
    "\n",
    "&emsp; 调整数据格式，这里就回应了前文未在数据集预处理时调整格式的问题，**总之参数匹配一定要求**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de96c1d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def input_mapping(data):\n",
    "    data['target'] = data['label']\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fc8b9f3",
   "metadata": {},
   "source": [
    "&emsp; 而`trainer`模块的基础方法列表如下，相关进阶操作，如“`on`系列函数”、`callback`控制，请参考后续的`tutorial-7`\n",
    "\n",
    "| <div align=\"center\">名称</div> |<div align=\"center\">功能</div> | <div align=\"center\">主要参数</div> |\n",
    "|:--|:--|:--|\n",
    "| `run` | 控制`trainer`中模型的训练和评测 | 详见后文 |\n",
    "| `train_step` | 实现`trainer`训练中一个批数据的前向传播过程 | 输入`batch` |\n",
    "| `backward` | 实现`trainer`训练中一次损失的反向传播过程 | 输入`output` |\n",
    "| `zero_grad` | 实现`trainer`训练中`optimizers`的梯度置零 | 无输入 |\n",
    "| `step` | 实现`trainer`训练中`optimizers`的参数更新 | 无输入 |\n",
    "| `epoch_evaluate` | 实现`trainer`训练中每个循环的评测，实际是否执行取决于评测频率 | 无输入 |\n",
    "| `step_evaluate` | 实现`trainer`训练中每个批次的评测，实际是否执行取决于评测频率 | 无输入 |\n",
    "| `save_model` | 保存`trainer`中的模型参数/状态字典至`fastnlp_model.pkl.tar` | `folder`指明路径，`only_state_dict`指明是否只保存状态字典，默认`False` |\n",
    "| `load_model` | 加载`trainer`中的模型参数/状态字典自`fastnlp_model.pkl.tar` | `folder`指明路径，`only_state_dict`指明是否只加载状态字典，默认`True` |\n",
    "| `save_checkpoint` | <div style=\"line-height:25px;\">保存`trainer`中模型参数/状态字典 以及 `callback`、`sampler`<br>和`optimizer`的状态至`fastnlp_model/checkpoint.pkl.tar`</div> | `folder`指明路径，`only_state_dict`指明是否只保存状态字典，默认`True` |\n",
    "| `load_checkpoint` | <div style=\"line-height:25px;\">加载`trainer`中模型参数/状态字典 以及 `callback`、`sampler`<br>和`optimizer`的状态自`fastnlp_model/checkpoint.pkl.tar`</div> | <div style=\"line-height:25px;\">`folder`指明路径，`only_state_dict`指明是否只保存状态字典，默认`True`<br>`resume_training`指明是否只精确到上次训练的批量，默认`True`</div> |\n",
    "| `add_callback_fn` | 在`trainer`初始化后添加`callback`函数 | 输入`event`指明回调时机，`fn`指明回调函数 |\n",
    "| `on` | 函数修饰器，将一个函数转变为`callback`函数 | 详见`tutorial-7` |\n",
    "\n",
    "<!-- ```python\n",
    "Trainer.__init__():\n",
    "\ton_after_trainer_initialized(trainer, driver)\n",
    "Trainer.run():\n",
    "\tif num_eval_sanity_batch > 0:                              # 如果设置了 num_eval_sanity_batch\n",
    "\t\ton_sanity_check_begin(trainer)\n",
    "\t\ton_sanity_check_end(trainer, sanity_check_res)\n",
    "\ttry:\n",
    "\t\ton_train_begin(trainer)\n",
    "\t\twhile cur_epoch_idx < n_epochs:\n",
    "\t\t\ton_train_epoch_begin(trainer)\n",
    "\t\t\twhile batch_idx_in_epoch<=num_batches_per_epoch:\n",
    "\t\t\t\ton_fetch_data_begin(trainer)\n",
    "\t\t\t\tbatch = next(dataloader)\n",
    "\t\t\t\ton_fetch_data_end(trainer)\n",
    "\t\t\t\ton_train_batch_begin(trainer, batch, indices)\n",
    "\t\t\t\ton_before_backward(trainer, outputs)           # 其中 outputs 是经过 output_mapping 后的\n",
    "\t\t\t\ton_after_backward(trainer)\n",
    "\t\t\t\ton_before_zero_grad(trainer, optimizers)  # 实际调用受到 accumulation_steps 影响\n",
    "\t\t\t\ton_after_zero_grad(trainer, optimizers)  # 实际调用受到 accumulation_steps 影响\n",
    "\t\t\t\ton_before_optimizers_step(trainer, optimizers)  # 实际调用受到 accumulation_steps 影响\n",
    "\t\t\t\ton_after_optimizers_step(trainer, optimizers)  # 实际调用受到 accumulation_steps 影响\n",
    "\t\t\t\ton_train_batch_end(trainer)\n",
    "\t\t\ton_train_epoch_end(trainer)\n",
    "\texcept BaseException:\n",
    "\t\tself.on_exception(trainer, exception)\n",
    "\tfinally:\n",
    "\t\ton_train_end(trainer)\n",
    "``` -->"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e21df35",
   "metadata": {},
   "source": [
    "紧接着，初始化`trainer`实例，继续完成`SST-2`分类，其中`metrics`输入的键值对，字串`'suffix'`和之前定义的\n",
    "\n",
    "&emsp; 字串`'prefix'`将拼接在一起显示到`progress bar`中，故完整的输出形式为`{'prefix#suffix': float}`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "926a9c50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    driver='torch',\n",
    "    device=0,  # 'cuda'\n",
    "    n_epochs=10,\n",
    "    optimizers=optimizers,\n",
    "    input_mapping=input_mapping,\n",
    "    train_dataloader=train_dataloader,\n",
    "    evaluate_dataloaders=evaluate_dataloader,\n",
    "    metrics={'suffix': MyMetric()}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1b2e8b7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "最后就是`run`函数的使用，关于其参数，这里也以表格形式列出，由此就解答了`num_eval_batch_per_dl=10`的含义\n",
    "\n",
    "| <div align=\"center\">名称</div> | <div align=\"center\">功能</div> | <div align=\"center\">默认值</div> |\n",
    "|:--|:--|:--|\n",
    "| `num_train_batch_per_epoch` | 指定`trainer`训练时，每个循环计算批量数目 | 整数类型，默认`-1`，表示训练时，每个循环计算所有批量 |\n",
    "| `num_eval_batch_per_dl` | 指定`trainer`评测时，每个循环计算批量数目 | 整数类型，默认`-1`，表示评测时，每个循环计算所有批量 |\n",
    "| `num_eval_sanity_batch` | 指定`trainer`训练开始前，试探性评测批量数目 | 整数类型，默认`2`，表示训练开始前评估两个批量 |\n",
    "| `resume_from` | 指定`trainer`恢复状态的路径，需要是文件夹 | 字符串型，默认`None`，使用可参考`CheckpointCallback` |\n",
    "| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型，默认`True`恢复所有状态，`False`仅恢复`model`和`optimizers`状态 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "43be274f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[09:30:35] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches.              <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#596\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">596</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
       ".get_parent()\n",
       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
       ".get_parent()\n",
       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
       ".get_parent()\n",
       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
       ".get_parent()\n",
       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6875</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8125</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.825</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8125</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n",
       "</pre>\n"
      ],
      "text/plain": [
       "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "  <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "  \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.run(num_eval_batch_per_dl=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1abfa0a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
